{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "llm = ChatOpenAI(model=\"qwen-max-latest\", temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import PromptTemplate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基础少样本学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def few_shot_sentiment_classification(input_text):\n",
    "    few_shot_prompt = PromptTemplate(\n",
    "        input_variables=[\"input_text\"],\n",
    "        template=\"\"\"\n",
    "        将情绪分类为正面、负面或中性。\n",
    "        \n",
    "        示例：\n",
    "        文本：我喜欢这个产品！它太棒了。\n",
    "        情绪：正面\n",
    "        \n",
    "        文本：这部电影很糟糕。我讨厌它。\n",
    "        情绪：负面\n",
    "\n",
    "        文本：今天的天气还可以。\n",
    "        情绪：中性\n",
    "\n",
    "        现在，对以下内容进行分类：\n",
    "        文本：{input_text}\n",
    "        情绪：\n",
    "        \"\"\",\n",
    "    )\n",
    "\n",
    "    chain = few_shot_prompt | llm\n",
    "    result = chain.invoke({\"input_text\": input_text}).content\n",
    "\n",
    "    # Clean up the result\n",
    "    result = result.strip()\n",
    "    # Extract only the sentiment label\n",
    "    if \":\" in result:\n",
    "        result = result.split(\":\")[1].strip()\n",
    "\n",
    "    return result  # This will now return just \"Positive\", \"Negative\", or \"Neutral\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入: 这家新餐厅这么棒！你敢信？\n",
      "预测情绪: 正面\n"
     ]
    }
   ],
   "source": [
    "test_text = \"这家新餐厅这么棒！你敢信？\"\n",
    "result = few_shot_sentiment_classification(test_text)\n",
    "print(f\"输入: {test_text}\")\n",
    "print(f\"预测情绪: {result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 高级少样本技术"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_task_few_shot(input_text, task):\n",
    "    few_shot_prompt = PromptTemplate(\n",
    "        input_variables=[\"input_text\", \"task\"],\n",
    "        template=\"\"\"\n",
    "        对给定的文本执行指定的任务。\n",
    "\n",
    "        示例：\n",
    "        文本：我喜欢这个产品！它太棒了。\n",
    "        任务：情绪\n",
    "        结果：积极\n",
    "\n",
    "        文本：你好，你好吗？\n",
    "        任务：语言\n",
    "        结果：英语\n",
    "\n",
    "        现在，执行以下任务：\n",
    "        文本：{input_text}\n",
    "        任务：{task}\n",
    "        结果：\n",
    "        \"\"\",\n",
    "    )\n",
    "\n",
    "    chain = few_shot_prompt | llm\n",
    "    return chain.invoke({\"input_text\": input_text, \"task\": task}).content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "积极\n",
      "英语\n"
     ]
    }
   ],
   "source": [
    "print(multi_task_few_shot(\"我真不敢相信这有多棒！\", \"情绪\"))\n",
    "print(multi_task_few_shot(\"I can't believe how great this is!\", \"语言\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 上下文学习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def in_context_learning(task_description, examples, input_text):\n",
    "    example_text = \"\".join(\n",
    "        [f\"输入: {e['input']}\\n输出: {e['output']}\\n\\n\" for e in examples]\n",
    "    )\n",
    "\n",
    "    in_context_prompt = PromptTemplate(\n",
    "        input_variables=[\"task_description\", \"examples\", \"input_text\"],\n",
    "        template=\"\"\"\n",
    "        任务: {task_description}\n",
    "        \n",
    "        示例:\n",
    "        {examples}\n",
    "        \n",
    "        现在，根据以下输入执行任务：\n",
    "        输入: {input_text}\n",
    "        输出:\n",
    "        \"\"\",\n",
    "    )\n",
    "\n",
    "    chain = in_context_prompt | llm\n",
    "    return chain.invoke(\n",
    "        {\n",
    "            \"task_description\": task_description,\n",
    "            \"examples\": example_text,\n",
    "            \"input_text\": input_text,\n",
    "        }\n",
    "    ).content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入: python\n",
      "输出: ythonpay\n"
     ]
    }
   ],
   "source": [
    "task_desc = \"将给定的文本翻译成猪拉丁。直接输出最终结果。\"\n",
    "examples = [\n",
    "    {\"input\": \"hello\", \"output\": \"ellohay\"},\n",
    "    {\"input\": \"apple\", \"output\": \"appleay\"},\n",
    "]\n",
    "test_input = \"python\"\n",
    "\n",
    "result = in_context_learning(task_desc, examples, test_input)\n",
    "print(f\"输入: {test_input}\")\n",
    "print(f\"输出: {result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 最佳实践与评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(model_func, test_cases):\n",
    "    \"\"\"\n",
    "    Evaluate the model on a set of test cases.\n",
    "\n",
    "    Args:\n",
    "    model_func: The function that makes predictions.\n",
    "    test_cases: A list of dictionaries, where each dictionary contains an \"input\" text and a \"label\" for the input.\n",
    "\n",
    "    Returns:\n",
    "    The accuracy of the model on the test cases.\n",
    "    \"\"\"\n",
    "    correct = 0\n",
    "    total = len(test_cases)\n",
    "\n",
    "    for case in test_cases:\n",
    "        input_text = case[\"input\"]\n",
    "        true_label = case[\"label\"]\n",
    "        prediction = model_func(input_text).strip()\n",
    "\n",
    "        is_correct = prediction.lower() == true_label.lower()\n",
    "        correct += int(is_correct)\n",
    "\n",
    "        print(f\"输入: {input_text}\")\n",
    "        print(f\"预测值: {prediction}\")\n",
    "        print(f\"真实值: {true_label}\")\n",
    "        print(f\"是否正确: {is_correct}\\n\")\n",
    "\n",
    "    accuracy = correct / total\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "输入: 这个产品超出了我的预期！\n",
      "预测值: 正面\n",
      "真实值: 正面\n",
      "是否正确: True\n",
      "\n",
      "输入: 我对这项服务非常失望。\n",
      "预测值: 负面\n",
      "真实值: 负面\n",
      "是否正确: True\n",
      "\n",
      "输入: 今天最高气温23度\n",
      "预测值: 中性\n",
      "真实值: 中性\n",
      "是否正确: True\n",
      "\n",
      "模型准确率: 1.00\n"
     ]
    }
   ],
   "source": [
    "test_cases = [\n",
    "    {\"input\": \"这个产品超出了我的预期！\", \"label\": \"正面\"},\n",
    "    {\"input\": \"我对这项服务非常失望。\", \"label\": \"负面\"},\n",
    "    {\"input\": \"今天最高气温23度\", \"label\": \"中性\"},\n",
    "]\n",
    "\n",
    "accuracy = evaluate_model(few_shot_sentiment_classification, test_cases)\n",
    "print(f\"模型准确率: {accuracy:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "prompt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
